//
//  RoiPoolGradTest.cpp
//  MNNTests
//
//  Created by MNN on 2022/11/23.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#include <MNN/expr/Expr.hpp>
#include <MNN/expr/ExprCreator.hpp>
#include "MNNTestSuite.h"
#include "TestUtils.h"
#include <cmath>
#include <cstring>
#include <vector>
#include "../tools/train/source/grad/OpGrad.hpp"

using namespace std;
using namespace MNN;
using namespace MNN::Express;

class RoiPoolGradTest : public MNNTestCase {
public:
    char name[20] = "RoiPool";
    virtual ~RoiPoolGradTest() = default;

    virtual bool run(int precision) {
        vector<int> inputShape = {3, 9, 4, 4};
        auto input = _Input(inputShape, NCHW);
        vector<float> inputData = {  1.2587,  0.8945, -0.7543,  2.0012,  1.3856,  0.9340, -0.1376,  0.1837,
                                    -0.8983, -0.6558, -0.3315, -0.1219,  0.3075,  1.4431,  0.4746, -0.5140,
                                    -0.1245, -0.3308,  0.0673,  0.7772,  0.0639, -0.6047,  0.1636, -0.3155,
                                    0.1523, -0.4973,  1.2428, -0.5529, -0.4270,  1.6364,  0.2426,  1.0559,
                                    0.9193, -0.3176,  0.0892, -0.2509, -2.4380, -0.9005,  0.5781,  1.4008,
                                    -0.5696,  0.7918, -2.0354, -1.2119, -1.3411, -0.3476, -0.0886, -0.0649,
                                    -0.7115, -0.5848, -1.0234, -1.5510,  0.0550, -1.4380,  0.7675, -0.5726,
                                    1.4670, -0.9947, -1.8784, -1.5924, -0.8265,  0.0891, -0.3028, -1.3168,
                                    -0.4735, -1.4944,  0.3030, -0.1165,  0.2647, -1.3413, -0.9787,  0.2888,
                                    2.8702, -1.2232, -0.8815,  1.9677,  0.5460, -0.8153, -0.5402, -0.7055,
                                    1.0696,  0.3802, -0.0949,  0.9391, -1.0831, -1.3940,  1.7714,  1.0268,
                                    1.0844, -1.3981,  0.4609, -0.7931, -0.3239,  1.2300,  1.4248, -1.2277,
                                    -1.5531, -0.3628,  0.3534,  1.1957, -0.5323, -0.5895, -1.6513,  0.8463,
                                    0.1348,  0.7655,  0.1805, -1.6148,  2.8097, -0.4605, -0.6506,  0.0297,
                                    0.6477, -1.1414, -1.0395,  0.0904,  1.5177, -0.0325,  0.5897,  0.3167,
                                    0.4292,  0.3140, -0.3639, -0.7091,  0.2055,  0.0503,  0.6292, -0.1367,
                                    0.7991, -1.3695, -0.7060, -0.0840,  0.3023,  0.4616, -0.7059, -1.6423,
                                    -1.3314,  1.2474,  0.5421,  1.4275, -0.8528, -0.6006,  0.2814,  0.6976,
                                    -0.6811, -1.9291,  1.2983,  0.4801, -1.1602,  0.4394, -0.3520,  0.6311,
                                    -0.3585,  0.1489,  0.9659,  0.5493, -0.9856, -1.1759,  0.9381,  0.5606,
                                    2.0309,  0.5102,  1.7770,  1.2903, -1.4298, -1.1124,  0.3458,  1.8255,
                                    0.5936,  1.3503,  0.9923, -0.9042,  0.0124,  1.0796, -0.2233, -1.0319,
                                    -1.3835,  0.8602,  0.0651, -0.1098, -1.9900, -1.1028, -1.0592,  0.8511,
                                    -0.2102, -0.7675, -0.6877, -0.4493, -0.5632,  0.3369,  0.5917, -0.6685,
                                    -1.1458, -1.9596,  1.8387, -0.2642,  1.4898, -0.7788,  0.7117, -1.2234,
                                    0.3939, -0.2793, -0.4268, -1.1598,  0.4164,  0.4359,  0.5211,  0.5965,
                                    2.0014,  0.7337,  0.3770, -0.8599,  0.7286, -0.9268,  0.1724, -1.1386,
                                    -0.1429,  0.7072,  0.5999, -0.2979, -0.8230, -0.8795, -0.5317,  0.0974,
                                    1.0004,  0.6322, -1.9103,  0.8706, -0.2598, -1.2323, -0.3205,  1.3420,
                                    0.3936,  2.0456,  1.7977, -1.1196, -0.5652, -1.3567,  0.9958, -2.0845,
                                    -1.3749, -0.7130,  1.0244,  0.0593,  0.0636,  0.2393, -1.3413, -0.3329,
                                    0.2147, -0.5064, -0.5119, -0.9965, -0.4002, -0.6242, -0.3976, -0.6084,
                                    1.2526,  0.7067, -0.2353, -0.5699,  0.0824,  1.0667, -0.0329, -0.1180,
                                    -1.6390,  0.7729,  0.7066, -2.2387, -0.7651, -0.4625,  1.9304, -0.1592,
                                    0.4796, -0.6125, -0.8265,  0.0568,  2.5158, -0.3929, -0.3927, -0.1145,
                                    0.4040,  0.2954, -0.7797,  0.0569,  0.3714,  0.5620, -0.6556,  0.0075,
                                    -0.0251, -1.6895, -0.8571,  0.3759,  0.0106,  1.0075,  1.3647, -0.1915,
                                    -0.1687, -1.9660,  0.7073, -1.0942, -0.1903,  1.2114,  0.6589,  0.7416,
                                    0.1255, -0.2084, -1.7247, -0.6163,  2.5999,  0.5725, -0.1817,  1.2373,
                                    1.6475, -0.3679, -0.7700,  0.5559, -0.0299, -0.9032, -1.6034,  0.8630,
                                    -0.9992, -0.5817,  0.5362,  1.4626, -0.7890, -1.0981,  1.7217,  0.7581,
                                    1.2861,  0.3955,  0.2466,  0.3384,  0.1506, -0.9613, -1.5495, -0.1552,
                                    -0.1180,  1.6468, -1.7706,  0.0055, -0.4989,  0.1550,  1.6259, -0.0722,
                                    0.7386,  0.0657,  0.5618, -0.4135, -1.3406,  0.7209, -0.1369, -0.4943,
                                    -0.3508, -0.2657,  1.5009,  1.8255,  0.7049,  0.9854,  0.3529,  1.1112,
                                    -0.9561,  0.7174, -1.1929, -1.2257,  0.1584,  0.7370, -0.7273,  0.8572,
                                    0.0591, -1.8631, -0.1637,  1.9188, -0.9281, -1.3265,  0.3382, -0.5424,
                                    1.4783, -0.0339,  0.8036,  0.1805,  1.2170, -2.1388, -0.4797,  1.1232,
                                    0.4213, -0.2824, -0.0592, -0.3094, -0.9494,  0.0946,  1.3795,  0.4063,
                                    0.1934, -0.2050,  1.1473, -1.8769, -1.3865,  0.0212,  0.5409,  0.8030,
                                    -2.4675,  0.6231,  0.1214, -1.3949, -1.0724, -1.5440, -0.4761,  1.4920,
                                    1.7146,  0.3414,  0.1379, -0.8818, -1.7559,  1.8605, -1.3545,  1.5024,
                                    -2.1116, -0.4113, -1.4620, -0.2642,  0.3996, -0.0468, -1.4671,  0.4811,
                                    0.0413,  0.4503, -0.2901,  1.9869, -0.3118, -1.3857,  1.3151, -0.7364};
        auto inputPtr          = input->writeMap<float>();
        memcpy(inputPtr, inputData.data(), input->getInfo()->size * sizeof(float));

        const float spatialScale = 1.0 / 16;
        const int pooledHeight = 3;
        const int pooledWidth = 3;

        auto roiInput = _Input({2, 5}, NCHW);
        vector<float> roiData = {   2, 1 / spatialScale, 2 / spatialScale, 3 / spatialScale, 3 / spatialScale,
                                    0, 0 / spatialScale, 2 / spatialScale, 2 / spatialScale, 3 / spatialScale};
        memcpy(roiInput->writeMap<float>(), roiData.data(), roiInput->getInfo()->size * sizeof(float));

        auto outputOri = _ROIPooling(_Convert(input, NC4HW4), _Convert(roiInput, NC4HW4), pooledHeight, pooledWidth, spatialScale);
        auto output = _Convert(outputOri, NCHW);
        auto outputPtr = output->readMap<float>();

        vector<float> outputTorch = {   -1.9660,  0.7073, -1.0942,  1.2114,  0.7073,  0.7416,  1.2114,  0.6589,
                                        0.7416, -0.3679, -0.7700,  0.5559, -0.3679, -0.7700,  0.8630, -0.9032,
                                        -1.6034,  0.8630,  0.3955,  0.2466,  0.3384,  0.3955,  0.2466,  0.3384,
                                        -0.9613, -1.5495, -0.1552,  0.0657,  0.5618, -0.4135,  0.7209,  0.5618,
                                        -0.4135,  0.7209, -0.1369, -0.4943,  0.7174, -1.1929, -1.2257,  0.7370,
                                        -0.7273,  0.8572,  0.7370, -0.7273,  0.8572, -0.0339,  0.8036,  0.1805,
                                        -0.0339,  0.8036,  1.1232, -2.1388, -0.4797,  1.1232, -0.2050,  1.1473,
                                        -1.8769,  0.0212,  1.1473,  0.8030,  0.0212,  0.5409,  0.8030,  0.3414,
                                        0.1379, -0.8818,  1.8605,  0.1379,  1.5024,  1.8605, -1.3545,  1.5024,
                                        0.4503, -0.2901,  1.9869,  0.4503,  1.3151,  1.9869, -1.3857,  1.3151,
                                        -0.7364, -0.8983, -0.6558, -0.3315,  0.3075,  1.4431,  0.4746,  0.3075,
                                        1.4431,  0.4746,  0.1523, -0.4973,  1.2428,  0.1523,  1.6364,  1.2428,
                                        -0.4270,  1.6364,  0.2426, -0.5696,  0.7918, -2.0354, -0.5696,  0.7918,
                                        -0.0886, -1.3411, -0.3476, -0.0886,  1.4670, -0.9947, -1.8784,  1.4670,
                                        0.0891, -0.3028, -0.8265,  0.0891, -0.3028,  2.8702, -1.2232, -0.8815,
                                        2.8702, -0.8153, -0.5402,  0.5460, -0.8153, -0.5402,  1.0844, -1.3981,
                                        0.4609,  1.0844,  1.2300,  1.4248, -0.3239,  1.2300,  1.4248,  0.1348,
                                        0.7655,  0.1805,  2.8097,  0.7655,  0.1805,  2.8097, -0.4605, -0.6506,
                                        0.4292,  0.3140, -0.3639,  0.4292,  0.3140,  0.6292,  0.2055,  0.0503,
                                        0.6292, -1.3314,  1.2474,  0.5421, -0.8528,  1.2474,  0.5421, -0.8528,
                                        -0.6006,  0.2814};

        for (int i = 0, count = 0; i < outputTorch.size(); ++i) {
            auto diff = ::fabsf(outputPtr[i] - outputTorch[i]);
            if (diff > 0.0001) {
                count++;
                MNN_ERROR("%d: %s output test failed, expected: %f, but got: %f!\n", count, name, outputTorch[i], outputPtr[i]);
                return false;
            } else {
                // MNN_PRINT("\toutput exact: %d, %f <==> %f\n", i, outputPtr[i], outputTorch[i]);
            }
        }

        auto opExpr = outputOri->expr().first;

        auto grad = OpGrad::get(opExpr->get()->type());
        vector<float> outputDiff = {0.2562, -0.0111, -1.5401,  0.5039, -0.6176, -0.1565,  1.8841, -0.3646,
                                    -2.0870,  2.0135, -0.6354, -0.6129,  1.3251,  1.6232, -0.9059, -0.1318,
                                    0.4667,  0.3912,  0.8568, -0.3556,  0.2248, -0.1303, -1.6850,  0.4877,
                                    -0.1433, -1.3551,  0.8345, -1.1855,  1.4541,  0.4225, -1.0868, -1.0298,
                                    0.8969, -0.3130,  0.5271, -0.8280, -0.6191,  0.5584, -1.8515,  0.6529,
                                    -1.1239,  0.8073,  0.3257,  2.0378, -0.7919,  0.8637,  1.3289,  1.1278,
                                    0.8832,  0.3839, -1.1428,  0.6202, -1.1006, -1.4295, -0.6698,  0.9958,
                                    -0.4719,  1.7213, -0.1548, -0.0358, -0.1978,  0.4558, -1.3004, -0.1816,
                                    0.4252,  0.9267, -0.9387,  1.0997,  1.2616,  1.7754, -0.5986,  0.4416,
                                    -0.2952, -0.8717,  0.1005, -0.3586,  1.5658, -1.1852,  0.9115, -0.5239,
                                    0.8183,  0.6974,  0.0715, -2.1861, -0.6542, -1.6065, -0.8234,  0.2259,
                                    0.5781, -0.6618, -0.1676,  1.8451,  0.5430, -0.9335, -0.1344, -1.1820,
                                    0.5422, -2.2710,  0.4764,  0.0155,  0.8077,  0.0861, -0.4085,  1.7200,
                                    0.0790, -0.2339, -0.0539,  0.4019, -0.2817, -0.3598, -1.2706,  0.2367,
                                    -0.8693, -1.2023,  1.0073,  1.4283,  0.0475, -2.9939, -0.6765, -0.9341,
                                    -0.5517, -0.9149,  0.2808, -0.4714,  0.4733,  0.5395,  0.3451,  0.2129,
                                    -1.2796, -0.2701,  2.2198,  1.2021,  0.5800,  0.5960,  1.4162, -1.1088,
                                    1.0207, -2.8389, -2.0961,  0.6317, -0.7982, -0.6363,  0.4101, -1.3526,
                                    0.9465, -0.1597,  0.9172, -1.0526,  1.3711,  0.7282,  1.0296,  0.3771,
                                    0.3484,  0.8067,  1.4168, -1.1119, -1.6325, -0.8257, -0.0769, -1.3230,
                                    1.1655, -0.4298};

        auto inputGrad = grad->onGrad(opExpr, {_Convert(_Const(outputDiff.data(), {2, 9, 3, 3}, NCHW), NC4HW4)});

        vector<float> expectedOutput = {    0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.6974,  0.0715, -2.1861,  0.0000, -0.4283, -1.0284, -1.4852,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            -1.1011,  1.8451, -0.6390,  0.0000,  0.5422, -2.4054,  0.4764,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            -0.3930,  2.5277,  0.0861,  0.0000, -0.2339, -0.0539,  0.4809,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            -0.0450, -0.3598, -1.2706,  0.0000,  1.0073,  0.5590, -1.1548,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            -3.5456, -0.6765, -0.9341,  0.0000, -0.4714, -0.4416,  0.8203,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0750,  0.2129, -1.2796,  0.0000,  0.5800,  2.8158,  2.6183,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            -1.1088,  1.6524, -3.6371,  0.0000, -2.7324,  0.4101, -1.3526,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            -0.1061,  1.2114,  0.9172,  0.0000,  1.0296,  0.3771,  1.0766,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.8067,  0.5911, -1.1888,  0.0000, -2.9555,  1.1655, -0.4298,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.2562, -0.6287, -1.5401,  0.0000,  2.3880, -0.3646, -2.2435,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  3.3386,  0.9878, -0.6129,  0.0000, -0.1318,  0.4667, -0.5147,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  0.7265, -2.0406,  0.7125,  0.0000, -0.1433, -1.3551,  0.8345,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000, -1.1855,  0.4243,  1.3194,  0.0000, -1.3998,  0.5271, -0.8280,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000, -0.6191,  0.5584, -1.8515,  0.0000,  0.9786,  0.9139,  0.0154,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000,  1.7469,  1.7128,  1.1278,  0.0000,  0.6202, -1.1006, -2.5723,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000, -0.6698,  0.8410, -0.4719,  0.0000,  1.5235,  0.4558, -1.3362,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000, -0.1816,  1.5249,  0.9267,  0.0000,  0.8367, -0.5986,  1.7032,
                                            0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
                                            0.0000, -0.6538, -0.8717, -1.0847,  0.0000,  0.9115,  1.0419,  0.8183};
        auto gotOutput = _Convert(inputGrad[0], NCHW)->readMap<float>();

        for (int i = 0; i < expectedOutput.size(); ++i) {
            auto diff = ::fabsf(gotOutput[i] - expectedOutput[i]);
            if (diff > 0.001) {
                MNN_ERROR("%s grad test failed, expected: %f, but got: %f!\n", name, expectedOutput[i], gotOutput[i]);
                return false;
            } else {
                // MNN_PRINT("%s grad exact, %f <==> %f\n", name, expectedOutput[i], gotOutput[i]);
            }
        }

        return true;
    }
};

// Register RoiPoolGradTest with the test suite under the name "grad/roi_pool".
// NOTE(review): MNNTestSuiteRegister is presumably the registration macro from MNNTestSuite.h -- confirm.
MNNTestSuiteRegister(RoiPoolGradTest, "grad/roi_pool");
